{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "05_03_VGG_CIFAR10.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/OUCTheoryGroup/colab_demo/blob/master/05_03_VGG_CIFAR10.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xVOKImkp77k4",
        "colab_type": "text"
      },
      "source": [
        "## 使用 VGG16 对 CIFAR10 分类\n",
        "\n",
        "VGG是由Simonyan 和Zisserman在文献《Very Deep Convolutional Networks for Large Scale Image Recognition》中提出卷积神经网络模型，其名称来源于作者所在的牛津大学视觉几何组(Visual Geometry Group)的缩写。\n",
        "\n",
        "该模型参加2014年的 ImageNet图像分类与定位挑战赛，取得了优异成绩：在分类任务上排名第二，在定位任务上排名第一。\n",
        "\n",
        "VGG16的网络结构如下图所示：\n",
        "\n",
        "![VGG16示意图](https://gaopursuit.oss-cn-beijing.aliyuncs.com/202003/20200229111521.jpg)\n",
        "\n",
        "16层网络的结节信息如下：\n",
        "- 01：Convolution using 64 filters\n",
        "- 02: Convolution using 64 filters + Max pooling\n",
        "- 03: Convolution using 128 filters\n",
        "- 04: Convolution using 128 filters + Max pooling\n",
        "- 05: Convolution using 256 filters\n",
        "- 06: Convolution using 256 filters\n",
        "- 07: Convolution using 256 filters + Max pooling\n",
        "- 08: Convolution using 512 filters\n",
        "- 09: Convolution using 512 filters\n",
        "- 10: Convolution using 512 filters + Max pooling\n",
        "- 11: Convolution using 512 filters\n",
        "- 12: Convolution using 512 filters\n",
        "- 13: Convolution using 512 filters + Max pooling\n",
        "- 14: Fully connected with 4096 nodes\n",
        "- 15: Fully connected with 4096 nodes\n",
        "- 16: Softmax\n",
        "\n",
        "### 1. 定义 dataloader\n",
        "\n",
        "**需要注意的是，这里的 transform，dataloader 和之前定义的有所不同，大家自己体会。**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Eea65R4Y-SDp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "\n",
        "# 使用GPU训练，可以在菜单 \"代码执行工具\" -> \"更改运行时类型\" 里进行设置\n",
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "transform_train = transforms.Compose([\n",
        "    transforms.RandomCrop(32, padding=4),\n",
        "    transforms.RandomHorizontalFlip(),\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])\n",
        "\n",
        "transform_test = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])\n",
        "\n",
        "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,  download=True, transform=transform_train)\n",
        "testset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)\n",
        "\n",
        "trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)\n",
        "testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)\n",
        "\n",
        "classes = ('plane', 'car', 'bird', 'cat',\n",
        "           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
      ],
      "execution_count": 0,
      "outputs": []
    },
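    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "`transforms.Normalize(mean, std)` computes `(x - mean) / std` for each channel, after `ToTensor` has scaled pixel values into the range [0, 1]. The plain-Python sketch below reproduces that arithmetic for a single pixel. The helper `normalize_pixel` and the mid-gray sample pixel are made up for illustration.\n",
        "\n",
        "```python\n",
        "# Per-channel statistics copied from the transforms above.\n",
        "MEAN = (0.4914, 0.4822, 0.4465)\n",
        "STD = (0.2023, 0.1994, 0.2010)\n",
        "\n",
        "def normalize_pixel(rgb):\n",
        "    # Apply (x - mean) / std to one RGB pixel with values in [0, 1].\n",
        "    return tuple((x - m) / s for x, m, s in zip(rgb, MEAN, STD))\n",
        "\n",
        "# A pixel equal to the channel means maps to zero in every channel.\n",
        "print(normalize_pixel(MEAN))  # (0.0, 0.0, 0.0)\n",
        "# A hypothetical mid-gray pixel ends up slightly above zero.\n",
        "print(normalize_pixel((0.5, 0.5, 0.5)))\n",
        "```\n",
        "\n",
        "Centering and rescaling the inputs this way keeps the three channels on a comparable scale before they reach the first convolution."
      ]
    },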
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k8ZSDhEOAWCM",
        "colab_type": "text"
      },
      "source": [
        "### 2. VGG 网络定义\n",
        "\n",
        "下面定义VGG网络，参数太多，我手动改简单了些~~~\n",
        "\n",
        "现在的结构基本上是：\n",
        "\n",
        "64 conv, maxpooling,\n",
        "\n",
        "128 conv, maxpooling,\n",
        "\n",
        "256 conv, 256 conv, maxpooling,\n",
        "\n",
        "512 conv, 512 conv, maxpooling,\n",
        "\n",
        "512 conv, 512 conv, maxpooling,\n",
        "\n",
        "softmax \n",
        "\n",
        "可能有同学要问，为什么这么设置？\n",
        "\n",
        "其实不为什么，就是觉得对称，我自己随便改的。。。\n",
        "\n",
        "下面是模型的实现代码："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oETMV_drAYXG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class VGG(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(VGG, self).__init__()\n",
        "        self.cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']\n",
        "        self.features = self._make_layers(cfg)\n",
        "        self.classifier = nn.Linear(2048, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = self.features(x)\n",
        "        out = out.view(out.size(0), -1)\n",
        "        out = self.classifier(out)\n",
        "        return out\n",
        "\n",
        "    def _make_layers(self, cfg):\n",
        "        layers = []\n",
        "        in_channels = 3\n",
        "        for x in cfg:\n",
        "            if x == 'M':\n",
        "                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]\n",
        "            else:\n",
        "                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),\n",
        "                           nn.BatchNorm2d(x),\n",
        "                           nn.ReLU(inplace=True)]\n",
        "                in_channels = x\n",
        "        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]\n",
        "        return nn.Sequential(*layers)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RRGlFA9NBLOH",
        "colab_type": "text"
      },
      "source": [
        "初始化网络，根据实际需要，修改分类层。因为 tiny-imagenet 是对200类图像分类，这里把输出修改为200。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PQbDShF-BGxk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# 网络放到GPU上\n",
        "net = VGG().to(device)\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "optimizer = optim.Adam(net.parameters(), lr=0.001)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HBSMhhMsBn9R",
        "colab_type": "text"
      },
      "source": [
        "### 3. 网络训练\n",
        "\n",
        "训练的代码和以前是完全一样的："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8_dYGlrYBrLv",
        "colab_type": "code",
        "outputId": "ee872455-f69a-4317-b929-3473985fe93e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 714
        }
      },
      "source": [
        "for epoch in range(10):  # 重复多轮训练\n",
        "    for i, (inputs, labels) in enumerate(trainloader):\n",
        "        inputs = inputs.to(device)\n",
        "        labels = labels.to(device)\n",
        "        # 优化器梯度归零\n",
        "        optimizer.zero_grad()\n",
        "        # 正向传播 +　反向传播 + 优化 \n",
        "        outputs = net(inputs)\n",
        "        loss = criterion(outputs, labels)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        # 输出统计信息\n",
        "        if i % 100 == 0:   \n",
        "            print('Epoch: %d Minibatch: %5d loss: %.3f' %(epoch + 1, i + 1, loss.item()))\n",
        "\n",
        "print('Finished Training')"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 1 Minibatch:     1 loss: 2.453\n",
            "Epoch: 1 Minibatch:   101 loss: 1.819\n",
            "Epoch: 1 Minibatch:   201 loss: 1.383\n",
            "Epoch: 1 Minibatch:   301 loss: 1.208\n",
            "Epoch: 2 Minibatch:     1 loss: 1.025\n",
            "Epoch: 2 Minibatch:   101 loss: 0.965\n",
            "Epoch: 2 Minibatch:   201 loss: 0.808\n",
            "Epoch: 2 Minibatch:   301 loss: 0.728\n",
            "Epoch: 3 Minibatch:     1 loss: 0.737\n",
            "Epoch: 3 Minibatch:   101 loss: 0.820\n",
            "Epoch: 3 Minibatch:   201 loss: 0.909\n",
            "Epoch: 3 Minibatch:   301 loss: 0.711\n",
            "Epoch: 4 Minibatch:     1 loss: 0.604\n",
            "Epoch: 4 Minibatch:   101 loss: 0.603\n",
            "Epoch: 4 Minibatch:   201 loss: 0.640\n",
            "Epoch: 4 Minibatch:   301 loss: 0.740\n",
            "Epoch: 5 Minibatch:     1 loss: 0.526\n",
            "Epoch: 5 Minibatch:   101 loss: 0.620\n",
            "Epoch: 5 Minibatch:   201 loss: 0.335\n",
            "Epoch: 5 Minibatch:   301 loss: 0.620\n",
            "Epoch: 6 Minibatch:     1 loss: 0.589\n",
            "Epoch: 6 Minibatch:   101 loss: 0.631\n",
            "Epoch: 6 Minibatch:   201 loss: 0.375\n",
            "Epoch: 6 Minibatch:   301 loss: 0.489\n",
            "Epoch: 7 Minibatch:     1 loss: 0.463\n",
            "Epoch: 7 Minibatch:   101 loss: 0.352\n",
            "Epoch: 7 Minibatch:   201 loss: 0.376\n",
            "Epoch: 7 Minibatch:   301 loss: 0.299\n",
            "Epoch: 8 Minibatch:     1 loss: 0.423\n",
            "Epoch: 8 Minibatch:   101 loss: 0.281\n",
            "Epoch: 8 Minibatch:   201 loss: 0.380\n",
            "Epoch: 8 Minibatch:   301 loss: 0.399\n",
            "Epoch: 9 Minibatch:     1 loss: 0.271\n",
            "Epoch: 9 Minibatch:   101 loss: 0.331\n",
            "Epoch: 9 Minibatch:   201 loss: 0.281\n",
            "Epoch: 9 Minibatch:   301 loss: 0.401\n",
            "Epoch: 10 Minibatch:     1 loss: 0.264\n",
            "Epoch: 10 Minibatch:   101 loss: 0.330\n",
            "Epoch: 10 Minibatch:   201 loss: 0.343\n",
            "Epoch: 10 Minibatch:   301 loss: 0.388\n",
            "Finished Training\n"
          ],
          "name": "stdout"
        }
      ]
    },
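    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The log above shows four lines per epoch because the CIFAR10 training split has 50,000 images and the loader uses a batch size of 128. A small sketch of that arithmetic, with illustrative variable names:\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "num_train = 50000   # size of the CIFAR10 training split\n",
        "batch_size = 128    # as passed to the DataLoader above\n",
        "\n",
        "# DataLoader keeps the last, smaller batch by default (drop_last=False).\n",
        "batches_per_epoch = math.ceil(num_train / batch_size)\n",
        "# The loop prints when i % 100 == 0 and reports i + 1 as the minibatch number.\n",
        "logged = [i + 1 for i in range(batches_per_epoch) if i % 100 == 0]\n",
        "\n",
        "print(batches_per_epoch)  # 391\n",
        "print(logged)             # [1, 101, 201, 301]\n",
        "```\n",
        "\n",
        "Each printed loss is the loss of a single minibatch, so it fluctuates from line to line even while the overall trend goes down."
      ]
    },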
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3-0fkETmhE_e",
        "colab_type": "text"
      },
      "source": [
        "# 4. 测试验证准确率：\n",
        "\n",
        "测试的代码和之前也是完全一样的。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dlwerDKmhiFw",
        "colab_type": "code",
        "outputId": "37717883-9523-4cab-eac6-7c34a3181196",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "correct = 0\n",
        "total = 0\n",
        "\n",
        "for data in testloader:\n",
        "    images, labels = data\n",
        "    images, labels = images.to(device), labels.to(device)\n",
        "    outputs = net(images)\n",
        "    _, predicted = torch.max(outputs.data, 1)\n",
        "    total += labels.size(0)\n",
        "    correct += (predicted == labels).sum().item()\n",
        "\n",
        "print('Accuracy of the network on the 10000 test images: %.2f %%' % (\n",
        "    100 * correct / total))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Accuracy of the network on the 10000 test images: 84.92 %\n"
          ],
          "name": "stdout"
        }
      ]
    },
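    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The evaluation loop picks, for every image, the class with the largest output score and counts how often it matches the label. The plain-Python sketch below mirrors that logic on made-up scores for three classes. The numbers are hypothetical and are not model outputs.\n",
        "\n",
        "```python\n",
        "# Hypothetical scores for 4 samples and 3 classes (not real model outputs).\n",
        "outputs = [\n",
        "    [2.0, 0.1, -1.0],\n",
        "    [0.3, 1.5, 0.2],\n",
        "    [0.0, 0.4, 0.9],\n",
        "    [1.2, 1.1, 0.0],\n",
        "]\n",
        "labels = [0, 1, 1, 0]\n",
        "\n",
        "# Plain-Python counterpart of torch.max(outputs, 1): index of the largest score per row.\n",
        "predicted = [row.index(max(row)) for row in outputs]\n",
        "correct = sum(p == y for p, y in zip(predicted, labels))\n",
        "\n",
        "print(predicted)                                  # [0, 1, 2, 0]\n",
        "print('%.2f %%' % (100 * correct / len(labels)))  # 75.00 %\n",
        "```\n",
        "\n",
        "Softmax is not needed at this step, because it does not change which class has the largest score."
      ]
    },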
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I33PhI7nPC5f",
        "colab_type": "text"
      },
      "source": [
        "可以看到，使用一个简化版的 VGG 网络，就能够显著地将准确率由 64%，提升到 84.92%\n",
        "\n",
        "使用哪些技术可以更进一步的提升性能呢？我们在后边的教程中会进一步学习。"
      ]
    }
  ]
}
